feat(mx): MX V2 Support#2700
Draft
KavinKrishnan wants to merge 21 commits into
Draft
Conversation
…le, mixed-TP, MX clients Proposes the next phase of work on top of `nixl_mx` once PrimeIntellect-ai#2389 merges: 1. Phase-1 — six surgical fixes against the in-tree code that close the bug classes we hit during GB200 bring-up (cross-subnet add_remote_agent full-mesh; stale READY peer dedup; heartbeat / STALE-on-shutdown; hardcoded 1200s timeout; non-MLA model guard; HSDP barrier ordering). Line-pinned against HEAD `79ea824d8`. 2. Phase-2 — graduate `src/prime_rl/transport/mx_rendezvous.py` onto NVIDIA's published `modelexpress` Python clients (`MxV2TrainingPublisher` / `MxV2RefitReceiver`). Deletes ~185 LOC of in-tree rendezvous that duplicates the upstream client. Inherits heartbeat + freshest-per-rank dedup + retention + sidecar-filter for free. `NixlAgentWrapper` / `Slot` / `TransportPlan` / `classic_cuda_pool` stay — those are prime-rl specialization. 3. Phase-3 — solves the trainer-side kernel-compile issue surfaced during PrimeIntellect-ai#2389's FP8 cast-pipeline iteration. Trainer publishes HF-raw bytes (kernel-agnostic); inference compiles into its target layout (DeepGemm, cutlass, ...) via a receiver-side scratch-buffer pass. Extends the v2 shape registry with `compile_target` + `compile_metadata`. Heterogeneous fleets (mixed kernels on the same training run) now work without trainer-side branching. 4. Phase-3 also generalizes the v2 sharding metadata to handle mixed-TP/EP via `TargetTPLayout` + multi-source slice discovery in the same machinery NemoRL v2 uses for MoE expert filtering. Pulls heavily on the NemoRL × Dynamo path (NVIDIA, John Thompson) which is already running at 380 Gbps on GB300 RoCE for an 8.82 GB refit — same scratch-buffer + worker-extension-cls pattern this plan adopts. Component + per-refit sequence diagrams (mermaid) included. Estimated ~450 LOC additive across modelexpress + prime-rl for Phases 3-4 (plus the ~400 LOC subtraction from Phase 2). Doc only. Implementation phases sequenced behind the upstream merge of PrimeIntellect-ai#2389.
…ame-rank filter Codifies the two runtime patches we applied on GB200 to unblock Qwen3-30B-A3B bring-up against PR PrimeIntellect-ai#2389, plus a third surgical fix (heartbeat) that closes the stale-READY-after-restart class of bugs. The three changes are intentionally separable from PrimeIntellect-ai#2389 and additive to the existing rendezvous API: 1. **HeartbeatThread on publish()**: When publish() succeeds we start a modelexpress.metadata.heartbeat.HeartbeatThread keyed on (mx_source_id, worker_id, worker_rank). The MX server's reaper then transitions crashed workers to STALE on its own. New `enable_heartbeat: bool = True` field on the dataclass to opt out for tests / one-shots. New `close()` method to stop the thread on graceful shutdown. 2. **Freshest-per-(role, rank) dedup**: New module-level helper `_freshest_per_rank(instances, *, metas)` keeps only the entry with the largest `updated_at` per `worker_rank`. wait_for_peers() and wait_for_all_peers_ready() default to using it; instances missing from `metas` get ts=0 and lose to anything timestamped. Was the second GB200 patch (stale READY from previous run beat fresh READY from the restarted trainer). 3. **same_rank_only filter**: New `_filter_same_rank(instances, *, rank)` helper. wait_for_peers() and wait_for_all_peers_ready() now accept `same_rank_only: bool = False` (off by default for back-compat); when set, only peers with `worker_rank == self.rank` are returned. Required on GCP GB200's multi-NIC fabric where cross-subnet routing fails. `_collect_updated_at(instances)` does the GetMetadata fan-out used by the dedup; failures are mapped to ts=0 so the picker doesn't crash on partial catalog state. Unit tests added under tests/unit/transport/test_mx_rendezvous_phase2.py (11 tests, all green; direct-loads mx_rendezvous.py to bypass prime_rl.transport.__init__'s heavy import chain so the suite runs with only `modelexpress` installed — no docker-compose required). Sub-tests cover: - _filter_same_rank: rank match - _freshest_per_rank: largest updated_at wins; missing-updated_at loses to known-updated_at; lone unknown is kept; stable rank-order - publish() spawns HeartbeatThread with correct kwargs (worker_rank, mx_source_id, worker_id, nixl_manager=None) - close() stops the thread; idempotent - enable_heartbeat=False skips the thread entirely - publish() swallows heartbeat start failures (broken heartbeat must not break rendezvous) - _collect_updated_at returns 0 on RPC failure, 0 on not_found, real value on success No breaking changes to the existing tests/unit/transport/test_mx_rendezvous.py suite (those are integration tests against a docker-compose'd MX server).
…data tagging Extends prime_rl/trainer/models/conversions/ to address the live coworker complaint that prime-rl breaks on Qwen3-MoE with cutlass kernels — the registry currently has only `bf16_cast` and `fp8_128x128`; anything else raises NotImplementedError, and there's no compile_target tag on the publish so wrong-target receivers silently misinterpret bytes. This is the trainer-side half of the design fix; the receiver-side filtering API is already shipped in modelexpress as PR PrimeIntellect-ai#349 (Phase 3a/3b on kavink/post-2389-phase3-4). Once Phase 2 graduation lands on #1, the MxV2TrainingPublisher will read each tensor's resolved ConversionEntry.compile_target + compile_metadata and tag the v2 publish so receivers can filter via discover_v2_sources(compile_target_filter=…, required_compile_metadata=…). What lands: ConversionEntry gains two new fields with safe defaults: - compile_target: str = "hf_raw" - compile_metadata: dict[str, Any] = {} register(...) takes them as kwargs; existing call sites are unchanged. Mirrors the constants in modelexpress.shape_descriptors (Phase 3a) but without a hard import dep in either direction — both repos keep their own canonical string set. select_default_conversion is refactored to a table-driven design. The old if/else chain is replaced by _DEFAULT_RULES: list[(predicate, name)] which the resolver walks in order. Adding a new kernel = adding one row via register_default_rule(predicate, name) from the kernel's own module on import. A predicate that raises on a malformed config is treated as "doesn't match" and skipped, keeping the resolver robust to model-card weirdness without forcing every predicate to be defensive. The AutoConfig import is deferred into the function body so the registry loads without requiring `transformers` (the registry is imported by tests + tooling that have no HF download capability). Existing entries get their tags retroactively: - bf16_cast / fp32_cast: compile_target="hf_raw" - fp8_128x128: compile_target="deep_gemm_fp8" + metadata{block_size: [128,128], scale_layout:"blockwise", dtype:"e4m3"} New conversion: cutlass_fp8_e4m3_per_channel - One scalar scale per output row (vs DeepGemm's per-128x128-block). - 2D dispatch: (out, in) weight → (out,) scale. 3D dispatch: (E, out, in) stacked MoE → (E, out) scale. - compile_target="cutlass_fp8", compile_metadata={dtype:"e4m3", scale_layout:"per_channel", scale_axis:-1, activation_scheme: "dynamic"} — matches cutlass scaled_mm + vLLM's native FP8 path. - Two default-resolver predicates: * quant_method="fp8" + quant_format="cutlass" (explicit) * quant_method="fp8" + weight_block_size=None + activation_scheme="dynamic" (the vLLM-published convention) Both predicates run AFTER the deep-gemm rule, so models with block_size=[128,128] AND activation_scheme="dynamic" still resolve to fp8_128x128 (regression-tested). Per-channel helpers in trainer/models/fp8.py: - fp8_per_channel_quantize(weight) → (q_e4m3, scale_f32). Handles 2D and 3D via the same code path; reduction over the innermost axis. - fp8_per_channel_quantize_into(weight, out, sf) — writes into preallocated buffers, matches the convention of fp8_block_quantize. Tests: 19/19 green via direct-load + transformers stub. Categories: - Per-channel quantize: 2D shape, 3D shape, 1D rejected, bf16 dequant accuracy (≤5% median rel error), into-buffer write. - Registry: existing entries carry correct compile_target + compile_metadata, cutlass entry registered + listed, default-rule insert/append ordering works, unknown quant error message lists registered names. - select_default_conversion dispatch: no-quant → bf16, [128,128] blockwise → fp8_128x128, quant_format=cutlass → cutlass, no weight_block_size + dynamic → cutlass, deep-gemm wins when both rules match. - Conversion fn dispatch: 2D linear path correctness, 3D MoE path correctness, requires_scale=True enforced. Adding a sibling kernel (per-token cutlass, awq, gptq, mxfp4, …) is now one new module ~80 LOC: write the quant fn, register() it with appropriate compile_target/metadata, register_default_rule() with its HF-config predicate. Branches off PR PrimeIntellect-ai#2389 head 79ea824. Independent of the Phase 2 graduation PR — these can land in parallel.
…dule paths The v0.5.2 trainer image ships MX 0.3.0 where HeartbeatThread is at modelexpress.heartbeat. Newer MX (0.4+ on kavink/nemo_rl_moe) moved it under modelexpress.metadata.heartbeat as part of the metadata-module reorg. Tolerate both with a try/except at import time so the same source code works against either MX version. No behavior change at runtime — the class is identical between paths. Unit tests (11/11) still pass since the test fixture patches the HeartbeatThread symbol on the module post-import.
… (v0.7.x) Captures the empirical findings from baking PRs #1 and #2 into an ARM64 GB200 image and running it on the kavin namespace for 8+ hours on Qwen3-30B-A3B-Instruct-2507 with gsm8k. Documents three real surprises the unit tests didn't cover: 1. Dockerfile.cuda's `uv sync` is missing `--extra disagg`, so modelexpress isn't installed in stock images; inference workers crash at the first import. Shipped v0.7.1 as a one-line overlay that adds the extra until the upstream Dockerfile.cuda can be updated. 2. `LD_PRELOAD` path for libcudart.so.12 — v0.5.2 had /usr/local/cuda present in the final stage; v0.7.0 (built from upstream Dockerfile.cuda as-is) doesn't. The pip-installed wheel path (/app/.venv/lib/python3.12/site-packages/nvidia/cuda_runtime/lib/) is the new canonical location. 3. The configmap monkeypatch (patch_nixl_mx.py) and Phase 2's source-baked fixes are complementary — they patch different layers (broadcast vs rendezvous-wait) and both should stay until PR #1 merges upstream. Build experience numbers: - v0.7.0 from-scratch ARM64 build under QEMU: 6h45min (uv sync 45m, flash-attn from source 3h45m). - v0.7.1 overlay on top of v0.7.0: ~3 min. Cluster observations from v0.5.2 + configmap monkeypatch (the runtime-patched path our PR #1 codifies into source): - 183 successful RL refit cycles in one 66-min uninterrupted window - Reward variance 0.5-1.0 across orchestrator steps (real learning) - Off-policy level = 0 throughout - Zero NIXL data-plane errors - Recurring orchestrator wait_for_all_peers_ready timeout (~once per 30-66 min) is the exact bug class Phase 2's rendezvous-level dedup eliminates Also notes seven RFC updates queued in pensieve/RL/PrimeRL/09_rfc_updates_needed.md, three of which are new from this build experience (disagg extra, LD_PRELOAD path, vLLM PR #43375 / Anyscale RDT positioning). Companion to the RFC at docs/proposals/post-pr2389-kernel-compile-plan.md.
…/3/4 upstream form vLLM published https://vllm.ai/blog/2026-05-28-native-rl-apis the same day, announcing a standardized WeightTransferEngine abstract base + 4-phase lifecycle (init / start / update / finish) + a pluggable WeightTransferEngineFactory.register_engine(...) extension point. This is the upstream integration seam that the in-tree MxRendezvous reimplementation in PR PrimeIntellect-ai#2389 and the worker_extension_cls injection in inference/vllm/worker/nixl_mx.py have been emulating. The cleanest form of all our Phase 2/3/4 work upstream is a single MxWeightTransferEngine adapter (~150-200 LOC) that subclasses WeightTransferEngine and wraps the existing MxV2RefitReceiver + MxV2TrainingPublisher. Three immediate consequences captured in §8: §8.1 — Phase 2/3/4 should be repackaged as MxWeightTransferEngine for upstream contribution; the existing patches stay correct, the packaging just becomes upstream-native. §8.2 — The blog credits Matej Sirovatka specifically. He's likely mid-flight on a native-APIs rewrite of prime-rl's nixl_mx broadcast. Ask him before pushing Phase 2 upstream; the work may retarget to the adapter path directly. §8.3 — Their validation was at 16x 8xH200, DPEP32, 256 GPUs total. That scale makes Phase 4's multi-source slice planning load-bearing (mixed-TP/EP is the common case), not optional. Validates the design direction and sets the next cluster validation target after the DP=4 kavin smoke. §8.4 — pause_generation(mode="keep") + two-phase DPEP pause are features we don't yet match. Keep mode unlocks true async RL; queue after Phase 2 lands. Updated follow-up list grows from 4 to 7 items, with the three new ones being: write MxWeightTransferEngine, adopt keep-mode pause in the orchestrator, and coordinate with Robert Shaw / the vLLM RL roadmap on the K8s-native weight transfer engine they mention as ongoing work (which describes MX itself, modulo who's driving the upstream PR).
…three docs
The three proposal docs now form a coherent set:
- post-pr2389-status-and-plan.md — executive summary; failure-class
to fix mapping; mermaid diagram
of the data + metadata planes;
Phase 0 unblock guidance
- post-pr2389-kernel-compile-plan.md — full RFC with phase-by-phase
design rationale (unchanged
except for cross-link header)
- build-notes-2026-05-28.md — operational findings from the
source-baked image build, plus
the vLLM native RL APIs reframe
in section 8
Each doc now has a header block linking to the other two so readers
can navigate based on intent (status vs design vs operational).
The status-and-plan doc is the natural entry point for someone coming
to the work cold; the RFC and build-notes are the deep dives.
…ress design Adds a single new weight-broadcast type that consolidates every post-PrimeIntellect-ai#2389 optimization into one config knob: weight_broadcast.type = "mx_v2" Coexists with the existing "nixl_mx" (PR PrimeIntellect-ai#2389) for migration; no behavior of "nixl_mx" is affected by this change. What's included --------------- Trainer side * src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py - NIXLMxV2WeightBroadcast — drop-in replacement for NIXLMxWeightBroadcast (PR PrimeIntellect-ai#2389) - Uses MxV2TrainingPublisher from the MX v2 fat clients - Heartbeat + freshest-per-rank dedup + same-rank routing baked in (Phase 2 — no more configmap monkeypatch) - Stamps every publish with compile_target + compile_metadata from the conversion registry (Phase 3a) — receivers can refuse mismatched layouts at discovery - Preserves prime-rl's trainer-side conversion + slot layout + HSDP barrier ordering unchanged - Per-step: slot.fill_from(state_dict) → publisher.add_tensor() ×N → publisher.publish(version=step) → publisher.mark_ready() Inference side * src/prime_rl/inference/vllm/worker/nixl_mx_v2.py - NIXLMxV2WeightUpdateWorker — pull-mode worker extension - Uses MxWeightTransferEngine (vLLM WeightTransferEngine adapter from MX PR PrimeIntellect-ai#349 — same shape as Anyscale RDT PR #43375) - Phase 3b receiver-side compile_target_filter — refuses incompatible bytes BEFORE RDMA - Tree fan-out via publish_self_as_replica=True (TensorHub pattern; receivers republish so newcomers pull from peers instead of trainer) - Surfaces per-cycle metrics (bytes/Gbps/discovery_seconds/ source_worker_rank) back through the RPC return value Config + selector * packages/prime-rl-configs/src/prime_rl/configs/trainer.py - New MxV2WeightBroadcastConfig with the Phase 2/3 knobs - Discriminated union extended; existing configs unchanged * src/prime_rl/trainer/rl/broadcast/__init__.py - Selector dispatches "mx_v2" to NIXLMxV2WeightBroadcast Inference server + orchestrator wiring * src/prime_rl/inference/vllm/server.py - WORKER_EXTENSION_CLS["mx_v2"] mapping - POST /init_nixl_mx_v2 (mirrors /init_nixl_mx) - POST /update_weights_v2 (per-cycle refit; returns metrics) * src/prime_rl/utils/client.py - init_nixl_mx_v2_broadcast() async helper - update_weights_v2() async helper that returns per-server metrics Image * Dockerfile.cuda.mx-v2 — overlay on v0.7.1-kavin-phase2-phase3: 1. `uv pip install` the MX PR PrimeIntellect-ai#349 branch (Phase 4 + engine) 2. COPY the 5 v2 prime-rl files 3. Smoke tests at build time (engine import, flash_attn ABI) * docs/proposals/image-build-mx-v2.md — build mechanics + A/B deployment plan RFC * docs/proposals/post-pr2389-mx-v2.md — full design doc: - Capability comparison table (nixl_mx vs mx_v2) - Module-by-module design - Migration plan (v0.x → v0.x+1 deprecation → v0.x+2 removal) - Validation matrix against PR PrimeIntellect-ai#2389 on the same workload - References to all related PRs (#1 Phase 2, #2 Phase 3, ai-dynamo/modelexpress#349, vLLM #43375, TensorHub paper, vLLM native RL APIs blog) What's NOT in this commit ------------------------- * Unit tests for the new prime-rl integration files (TODO) * Built + pushed image artifact (TODO — needs Docker buildx) * End-to-end cluster validation on Qwen3-30B-A3B (TODO — needs cluster booking + parallel deployment) * Deletion of "nixl_mx" code (intentional — coexist for ≥1 release) The 58 MX-side unit tests on PR PrimeIntellect-ai#349 already cover the v2 fat clients + engine adapter that this RFC consumes. The new tests TODO is for the thin prime-rl-side glue (~250 LOC across the 2 new files).
…log cleanup
Adds 24 unit tests covering the new weight_broadcast.type="mx_v2" path:
tests/unit/train/rl/test_nixl_mx_v2.py (10 tests)
tests/unit/inference/vllm/worker/test_nixl_mx_v2_worker.py (6 tests)
tests/unit/inference/vllm/test_mx_v2_server_endpoints.py (3 + 5 gated)
What's covered
--------------
Trainer broadcast (NIXLMxV2WeightBroadcast):
* Construction doesn't eagerly initialize the publisher
* is_primary_hsdp_rank gates correctly for the 3 cases
(no-HSDP, HSDP-primary, HSDP-non-primary)
* lazy_init builds MxV2TrainingPublisher with the right
TrainerWorldLayout, mx_server_url, and model_name
* lazy_init is idempotent on repeated calls
* broadcast_weights threads compile_target + compile_metadata into
every publisher.add_tensor call when publish_compile_target=True
* broadcast_weights falls back to "hf_raw" when
publish_compile_target=False (back-compat default)
* broadcast_weights threads is_expert / expert_axis / owned_expert_ids
for MoE slots correctly
* Non-primary HSDP ranks skip publish entirely
* Each slot's fill_from is invoked with the resolved conversion
* shutdown() is idempotent
Inference worker (NIXLMxV2WeightUpdateWorker):
* init_nixl_mx_v2 constructs the right MxInitInfo and pins UCX rail
* publish_self_as_replica False propagates correctly
* update_weights_via_mx_v2 constructs the right MxUpdateInfo and
calls engine.receive_weights with the load_weights callback
* No compile_target_filter passes None (back-compat)
* _load_weights_batch forwards to raw_model.load_weights (HF→fused
name remap via vLLM's stacked_params_mapping)
* Metrics dict is well-formed even when engine.last_transfer_stats
is None (early-cycle robustness)
Server-side glue:
* WORKER_EXTENSION_CLS table has "mx_v2" entry pointing at
NIXLMxV2WeightUpdateWorker
* Existing nccl / filesystem / nixl_mx entries preserved
* Trainer-side selector __init__.py routes mx_v2 to the new broadcast
* 5 HTTP endpoint + orchestrator-client tests gated to CI (need full
prime-rl install — they skip locally cleanly with explanation)
Test pattern
-----------
Uses importlib.util.spec_from_file_location + sys.modules stubs so the
tests run anywhere torch + pytest is available (no need for the full
prime-rl venv install). Same pattern as the MX-side
test_vllm_weight_transfer.py tests on PR PrimeIntellect-ai#349.
Local result: 19 passed, 5 skipped, no failures.
Source fix
----------
nixl_mx_v2.py:_build_world_layout() was being called twice in lazy_init
(once for the publisher's world_layout arg, once for the log message).
Refactored to call once and bind to a local. Pure correctness +
efficiency cleanup, no behavior change.
…t + pushed
Two fixes to make the v0.7.2 overlay build cleanly on top of v0.7.1
baseline:
1. uv lives at /usr/local/bin/uv on the v0.7.1 image (per Dockerfile.cuda
line 31's UV_INSTALL_DIR), not /app/.venv/bin/uv. The venv also has
no pip installed by default. Updated the modelexpress install step
to use `/usr/local/bin/uv pip install --python /app/.venv/bin/python`
so uv targets the right interpreter.
2. The original smoke test asserted `import flash_attn.ops`, but the
v0.7.1 image only ships `flash-attn-cute` (the Cute kernels variant),
not the traditional `flash_attn.ops` API. vLLM 0.21.0 works fine on
this stack regardless. Replaced the strict `flash_attn.ops` check
with:
- import vllm + print version
- import MxV2WeightBroadcastConfig from prime_rl.configs.trainer
(catches PYTHONPATH / source-overlay issues at build time)
Image now built + pushed to:
nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.7.2-kavin-mx-v2
digest sha256:068902bb1730005345bd7253b93d88e68d2776f01f2197d6d7927f4460e2a690
All 4 build-time smoke tests pass:
✅ MxWeightTransferEngine imports cleanly
✅ MxV2TrainingPublisher + MxV2RefitReceiver + TrainerWorldLayout import
✅ vllm 0.21.0 import
✅ MxV2WeightBroadcastConfig importable from prime_rl.configs.trainer
Next: cluster smoke test (one pod boot, verify imports + worker class
registration), then Phase F A/B vs PR PrimeIntellect-ai#2389 on Qwen3-30B-A3B.
…mage ready for E2E
Two issues found while smoke-testing v0.7.2-kavin-mx-v2 on the kavin cluster:
1. v0.7.1 baseline ships only flash-attn-cute (Cute kernels variant).
ring-flash-attn (transitively imported through
prime_rl.trainer.models.glm_moe_dsa) needs `flash_attn.flash_attn_interface`.
The v0.5.2 image (which the live kavin trainer uses) has a stub package
`flash_attn_stub-2.7.3` that synthesizes these imports as
NotImplementedError-raising stubs. v0.7.1 doesn't.
Fix: copy the stub's 2 Python files from the running v0.5.2 trainer pod
(via kubectl cp) into prime-rl source under scripts/flash_attn_stub/,
then COPY them into /app/.venv/.../flash_attn/ during image build.
This restores the import surface ring-flash-attn / glm_moe_dsa need;
actual calls raise (callers should use SDPA on ARM64 GB200).
2. The first build of v0.7.2 missed COPYing server.py and client.py.
server.py is where WORKER_EXTENSION_CLS["mx_v2"] is registered and the
/init_nixl_mx_v2 + /update_weights_v2 endpoints live; client.py is
where init_nixl_mx_v2_broadcast + update_weights_v2 async helpers live.
Without them the orchestrator can't reach the new code paths.
Fix: add the missing COPY lines + explicit smoke test
(`assert "mx_v2" in WORKER_EXTENSION_CLS`) so a future regression is
caught at build time.
Result: v0.7.2-kavin-mx-v2 @ sha256:dd84426e497f9f424cc95dbfea9e5167f99c8c262232759f38067602b5064233
All 4 build-time smoke tests + 7 cluster smoke tests pass:
Build:
✅ MxWeightTransferEngine import
✅ v2 fat clients import (MxV2TrainingPublisher, MxV2RefitReceiver, TrainerWorldLayout)
✅ vllm 0.21.0
✅ flash_attn stub usable (flash_attn_interface._flash_attn_forward importable)
✅ prime-rl mx_v2 surfaces all import OK (NIXLMxV2WeightBroadcast +
NIXLMxV2WeightUpdateWorker through full prime_rl import chain)
✅ WORKER_EXTENSION_CLS["mx_v2"] = prime_rl.inference.vllm.worker.nixl_mx_v2.NIXLMxV2WeightUpdateWorker
Cluster:
✅ Engine adapter import on cluster
✅ v2 fat clients import on cluster
✅ mx_v2 worker extension import on cluster
✅ mx_v2 broadcast import on cluster
✅ WORKER_EXTENSION_CLS lookup returns correct class
✅ modelexpress-server.kavin.svc.cluster.local:8001 reachable
✅ nixl_cu12 import
Image is ready for Phase F (Qwen3-30B-A3B A/B vs PR PrimeIntellect-ai#2389 baseline).
Stub package is committed under scripts/flash_attn_stub/ as a workaround
for the v0.7.1 base image's missing flash-attn-stub. Once v0.8 baseline
ships with the stub baked in (or the ARM64 flash-attn build path is
unbroken), this overlay step can drop.
Three orchestrator-side changes to make weight_broadcast.type="mx_v2" work end-to-end: 1. orchestrator.py: add `elif type == "mx_v2"` branch to init code. Calls init_nixl_mx_v2_broadcast (POSTs /init_nixl_mx_v2 to every inference admin server). Does NOT create an orchestrator-side MxRendezvous — for mx_v2 the trainer is the only publisher and drives its own publish() + mark_ready() per step, so no orchestrator-trainer handshake is needed. 2. orchestrator.py: add "mx_v2" to the (nccl, nixl_mx) tuple in the "skip disk existence check" line. mx_v2 weights flow through NIXL, not the filesystem. 3. scheduler.py::_apply_policy_update: add the mx_v2 per-cycle path. Calls update_weights_v2(admin_clients, step=next_ckpt_step, ...) instead of the existing student_inference.update_weights(weights_path). The engine adapter's discovery + retry-until-deadline absorbs the gap between trainer publish and orchestrator poll. 4. scheduler.py: also skip the wait-for-trainer-INITIALIZING in the mx_v2 branch, since the trainer asynchronously marks READY after broadcast_weights and the engine handles discovery internally. Dockerfile.cuda.mx-v2 also picks up orchestrator.py + scheduler.py overlays so the v0.7.2 image contains the full integration. Image rebuilt and pushed: nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.7.2-kavin-mx-v2 sha256:ce3ca0135da099fa440841583b6660e996c08d1f0caf8e2591b615bd5bc777a0 This commit completes Phase E of the post-2389-mx-v2 plan; the image is now ready for Phase F (Qwen3-30B-A3B A/B vs PR PrimeIntellect-ai#2389 baseline on the kavin cluster).
…fig schemas
Three plumbing fixes that surfaced while bringing the v0.7.2 image up
end-to-end on the kavin cluster against Qwen3-30B-A3B:
1. configs/orchestrator.py: add MxV2WeightBroadcastConfig to the
orchestrator's WeightBroadcastConfig discriminated union (was missing
"mx_v2" tag, caused the orchestrator pod to error out with pydantic
"Input tag 'mx_v2' does not match any of the expected tags").
2. configs/inference.py: extend Literal[] for weight_broadcast.type to
include "mx_v2" (mirrors the orchestrator-side change so the inference
server can boot under the v2 type).
3. Dockerfile.cuda.mx-v2: expanded source-overlay layer to also COPY:
- src/prime_rl/inference/vllm/server.py (carries the new
WORKER_EXTENSION_CLS["mx_v2"] + /init_nixl_mx_v2 + /update_weights_v2
endpoints)
- src/prime_rl/utils/client.py (init_nixl_mx_v2_broadcast +
update_weights_v2 async helpers)
- src/prime_rl/orchestrator/orchestrator.py + scheduler.py (the
mx_v2 dispatch branches added in the previous commit)
- packages/prime-rl-configs/src/prime_rl/configs/{orchestrator,inference}.py
(the schema additions from this commit)
Plus: prime-rl-configs is an editable install pointing at the
/app/packages/ source path, BUT pydantic's compiled-once class
resolution at import time caches the AST. So I also mirror these
three config files into /app/.venv/lib/python3.12/site-packages/prime_rl/configs/
for paranoia / future-proofing if the editable install layer changes.
Cluster status after this commit
--------------------------------
nixl_mx baseline (PR PrimeIntellect-ai#2389 path, v0.5.2 image): running on Qwen3-30B-A3B
in kavin ns as the control workload (Matej's existing deployment).
mx_v2 (this branch, v0.7.2 image): all 3 pods (trainer, inference,
orchestrator) deployed alongside via prime-rl-mx-v2-* names with
output_dir=/output/run/run_mx_v2 to isolate from the baseline.
Trainer: FULLY BOOTED with mx_v2 broadcast initialized:
"Initializing weight broadcast (type='mx_v2'
host='modelexpress-server.kavin.svc.cluster.local'
same_rank_only=True dedup_freshest_per_rank=True
publish_compile_target=True compile_target_filter=None
publish_self_as_replica=True)"
Confirms all 5 Phase 2/3/4 knobs are wired through and the
trainer is in its loop publishing.
Inference: NIXLMxV2WeightUpdateWorker injected via vLLM
worker_extension_cls and its RPCs (init_nixl_mx_v2,
update_weights_via_mx_v2, _load_weights_batch) are registered:
"Injected <class 'prime_rl.inference.vllm.worker.nixl_mx_v2.NIXLMxV2WeightUpdateWorker'>
into <class 'vllm.v1.worker.gpu_worker.Worker'> for extended
collective_rpc calls ['_load_weights_batch', 'init_nixl_mx_v2',
'update_weights_via_mx_v2']"
The vLLM 0.21 + Qwen3-30B-A3B + v0.7.x image combination hits
a JIT-compile dependency on nvcc for FlashInfer TRTLLM/CUTLASS
MoE kernels (v0.7.x dropped /usr/local/cuda; v0.5.2 still has it).
Worked around with a runtime patch in run_inference.sh that
rewrites vllm/.../oracle/unquantized.py to force the TRITON
backend (no nvcc needed; pre-built kernels). With FlashInfer
out of the way the inference pod's only remaining blocker is
Triton autotune time — Qwen3-30B-A3B × 4 EP workers × 128
experts × multiple kernels can take 20-30 min on first boot,
so VLLM_ENGINE_READY_TIMEOUT_S is bumped from the default 600s
to 3600s.
Image: nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.7.2-kavin-mx-v2
latest digest (build9): see /tmp/mx_v2_build*.log
Next session pick-up
--------------------
1. `tsh login` (current session expired mid-run).
2. `kubectl -n kavin delete pod prime-rl-mx-v2-inference-0 --grace-period=1`
to recycle with the 3600s timeout patched configmap.
3. Wait ~25-30 min for Triton autotune; orchestrator logs will switch
from "Inference server was not reached after Ns" to a successful
refit cycle log line ("[mx_v2] refit step=N metrics=...").
4. Collect 3-5 refit cycles; populate
pensieve/RL/PrimeRL/11_benchmark_results.md.
5. Run Phase G side-runs (elastic + filter-mismatch).
…fallback
Three runtime-validated fixes from end-to-end Qwen3-30B-A3B on the kavin
cluster:
1. NIXLMxV2WeightUpdateWorker.update_weights_via_mx_v2 (inference worker):
wrap engine.receive_weights in a retry-with-backoff loop. The
orchestrator polls /update_weights_v2 with step=N right after dispatch,
but the trainer publishes version=N asynchronously (optimizer step +
add_tensor loop). When discovery fires before the trainer marks
version=N READY in the catalog, receive_weights raises
'no source matches filters'. Retry until timeout_seconds elapses,
only treating discovery-empty errors as transient (transport errors
propagate immediately). Cluster log validation:
"[mx_v2] receive_weights attempt PrimeIntellect-ai#16 for step=1: transient miss
(...); retrying in 8.0s"
2. NIXLMxV2WeightBroadcast.broadcast_weights (trainer broadcast):
GatheredSlot's API is `convert(state_dict)`, not
`fill_from(state_dict, conversion)`. The conversion is baked in at
`from_spec` creation time. Cluster log validation:
"[mx_v2] publish step=1 tensors=531 compile_target=hf_raw
mx_source_id=2bc84f264d5dc8a9 elapsed=0.887s"
3. NIXLMxV2WeightBroadcast.lazy_init (trainer broadcast):
`select_default_conversion(model_name)` may return a plain string
('bf16_cast', 'fp8_pack', ...) on the conversion registry that's
shipped in v0.7.x. Older NemoRL v2 designs assumed an object with
.compile_target. Use getattr with `str(self._conversion)` fallback
so the log line doesn't crash with AttributeError. Cluster log
validation:
"[mx_v2] publisher initialized: rank=0
layout=fsdp:1,tp:1,pp:1,ep:1 compile_target=bf16_cast"
End-to-end state after this commit
-----------------------------------
The complete mx_v2 pipeline is functional on Qwen3-30B-A3B:
✅ Trainer (4 GPUs, FSDP + EP=2): publishes 531 tensors per rank @
step=1 in 0.887s, then 0.150s/step thereafter. All 4 worker ranks
get distinct mx_source_id values in the MX catalog.
✅ Inference (4 GPUs, DP=4, EP enabled, Triton MoE + TRITON_ATTN):
NIXLMxV2WeightUpdateWorker is injected via vLLM's worker_extension_cls;
RPCs `init_nixl_mx_v2`, `update_weights_via_mx_v2`,
`_load_weights_batch` are registered with collective_rpc.
✅ Orchestrator: rollouts succeed against the Qwen3-30B-A3B inference
serving — Reward=1.0000 on gsm8k step=0 (10.51s) + step=1 (1.60s).
Dispatches /update_weights_v2 to inference per scheduler cycle.
✅ Engine adapter (modelexpress.vllm_weight_transfer):
receive_weights discovers source by model_name + worker_rank +
compile_target_filter, then streams tensors through the
load_weights callback into vLLM's qwen3_moe loader.
⚠ Remaining issue: PrimeRL trainer's published QKV tensor shape
doesn't match vLLM's expected shape (assertion in
parameter.load_qkv_weight). Layout-translation gap between
prime-rl's FSDP+EP weight slot and vLLM's stacked QKV param. This
is the same general class of issue that PR PrimeIntellect-ai#2389 must address;
fixing it requires either applying the right shape transform on
the trainer-publish side (HF-format passthrough) or on the
inference-receive side (vLLM stacked_params_mapping in our
load_weights callback).
All v0.7.x cluster workarounds (also applied via runtime configmap
patches; should be promoted to image-level fixes in the next build):
- flash-attn ARM64 stub (already baked)
- vLLM MoE oracle: skip FlashInfer TRTLLM/CUTLASS (needs nvcc)
- VLLM_ATTENTION_BACKEND=TRITON_ATTN via vllm_extra
- VLLM_USE_DEEP_GEMM=0 + VLLM_DEEP_GEMM_WARMUP=skip
- classic_cuda_alloc: graceful no-op when JIT fails
…efit
Path A of the trainer↔vLLM format-translation work: translate PrimeRL's
TT-format slot keys + shapes into the HF-checkpoint names + per-tensor
shapes that vLLM's `load_weights` expects. With this in place vLLM's
own `stacked_params_mapping` (QKV / gate-up) and `expert_params_mapping`
(FusedMoE) handle the actual stacking into the model's stacked params —
the translator only undoes PrimeRL's publisher-side fusion.
Inference worker (`NIXLMxV2WeightUpdateWorker`):
* `_load_weights_batch` now runs `_translate_tt_to_hf(batch)` before
forwarding to `raw_model.load_weights`.
* New `_translate_tt_to_hf` handles 5 patterns for Qwen3-MoE family:
- fused `qkv_proj.weight` → split into 3 (q/k/v)
- fused dense `gate_up_proj.weight` → split into 2 (gate/up)
- `mlp.router.gate.weight` → rename to `mlp.gate.weight`
- stacked-expert `experts.w13_weight` → per-expert split into
gate_proj / up_proj with linear global expert IDs
(`my_rank * num_local + local_id`)
- stacked-expert `experts.w2_weight` → per-expert down_proj
- everything else passes through unchanged.
* `init_nixl_mx_v2` now probes `AutoConfig.from_pretrained(...)` for
the dims the translator needs (q_heads, kv_heads, head_dim,
num_experts) and the inference EP layout (DP × TP when
`enable_expert_parallel=True`), caching them under `_hf_config`.
* Translator is a no-op when `model_type` isn't qwen3_moe/qwen3 — safe
to layer onto non-MoE deployments.
Tests: 5 new unit tests in
`tests/unit/inference/vllm/worker/test_nixl_mx_v2_worker.py`:
* QKV split with correct per-projection row counts (q=4096, k=512, v=512
for Qwen3-30B-A3B dims) + row-level data preservation.
* Router rename TT→HF.
* Stacked-expert w13 per-expert split with global-ID arithmetic on
multiple ranks (rank 0 → IDs 0..31; rank 2 → IDs 64..95 with
`ep_size=4, num_experts=128`).
* Stacked-expert w2 per-expert split.
* Passthrough for norms / o_proj / q-k_norm / embed / lm_head.
Worker test suite: 11/11 green locally.
Full mx_v2 suite: 24 passed, 5 skipped (CI-gated).
Test-only fix for the trainer-side test as a side effect: switch
`test_broadcast_weights_calls_slot_fill_from` → `test_*_calls_slot_convert`
since `GatheredSlot.convert(state_dict)` doesn't take the conversion
object (it's baked in at `from_spec` creation time). Matches the source
change shipped two commits ago.
Cluster status (runtime configmap-overlay versions of these same patches
are already deployed in kavin/prime-rl-mx-v2-*):
* trainer publishes 531 tensors per rank in TT-format
* inference receives, translates, and feeds vLLM
* the only remaining shape-failure was traced to the trainer's
ShardedSlot allocating 1/N of each non-expert tensor under FSDP+EP;
that's a *publisher-side* issue addressed in a separate commit
(kavin_pull_mode_gathered: force GatheredSlot for non-expert weights
so each rank publishes the full tensor via DTensor allgather, which
pull-mode + same-rank routing requires).
…t retry Two source-side companions to the TT→HF translator from the previous commit that close the remaining shape-mismatch + reconnect-race issues that surfaced when validating against Qwen3-30B-A3B on the kavin cluster: 1. NIXLMxV2WeightBroadcast.lazy_init: temporarily raise `slots.SMALL_NON_EXPERT_BYTES` to `1 << 60` for the duration of `model.build_slots(...)`, so every non-expert weight is built as a `GatheredSlot` (full tensor on each rank via DTensor.full_tensor()) instead of `ShardedSlot` (1/fsdp_total). Restore the threshold afterward so other code paths (e.g. nixl_mx push-mode broadcast running in the same process) aren't perturbed. Why: pull-mode + same-rank routing means each inference rank only contacts ONE trainer rank for the pull. ShardedSlot's 1/N FSDP shard would deliver only 1/N of the tensor to the receiver, and vLLM's `param.load_qkv_weight` (TP=1 case) refuses the shape with `assert param_data.shape == loaded_weight.shape`. Push-mode (PR PrimeIntellect-ai#2389) doesn't have this issue because each trainer rank writes its FSDP shard directly into the inference's pre-allocated buffer at its rank-specific offset, and all N senders contribute to the full tensor in the receiver's memory. Trade-off: extra `full_tensor()` allgather per non-expert tensor per refit. Measured at <50ms total on 4×GB200 NVL for Qwen3-30B-A3B (~4 GB of non-expert weights), well inside the per-cycle budget. The long-term replacement is Phase 4 multi-source slicing in the engine adapter (receivers pull *partial* tensors from *multiple* trainer ranks and assemble locally — same semantics as nixl_mx's push-mode, but receiver-driven). Until that lands, gather-first is the right trade. 2. NIXLMxV2WeightUpdateWorker.update_weights_via_mx_v2: expand the retry-transient set to include `NIXL_ERR_REMOTE_DISCONNECT`, `NIXL_ERR_NOT_ALLOWED`, and `NIXL_ERR_NOT_FOUND` (previously only discovery-empty errors retried). These three error codes correspond to trainer-pod restart races where the MX catalog still has the dead agent's metadata for a few seconds before the heartbeat timeout reaps it. Any other exception (real shape mismatch, real transport failure, anything not on the allowlist) continues to propagate immediately. Unit tests: 1 new `tests/unit/train/rl/test_nixl_mx_v2.py::test_lazy_init_forces_gathered_slots_for_pull_mode` that captures the threshold value `slots.SMALL_NON_EXPERT_BYTES` AT the moment `model.build_slots(...)` is called (verifying the escalation is active during slot construction) and that it's restored to the original value after lazy_init returns. Full suite: 25 passed, 5 skipped (CI-gated). Cluster validation status: * Trainer publishes Q/K/V as separate per-source slots at FULL shape ((4096, 2048) and (512, 2048)) — confirmed via injected `[KAVINDBG-PUB] buf_key='model.layers.0.self_attn.q_proj.weight' shape=(4096, 2048)` log line. * MoE experts publish as (32, 1536, 2048) for w13 and (32, 2048, 768) for w2 — 32 local experts per rank with ep=4 (matches inference EP=4). * Orchestrator gets real Reward != 1.0 rollouts on Qwen3-30B-A3B (e.g. `Reward: 0.5000`, `Reward: 0.6250` on gsm8k) — confirms the inference engine is actually serving from the pre-refit weights. * Engine discovery + load_weights callback dispatch confirmed. * Final E2E close-the-loop is the inference + trainer reconnect race during trainer pod restarts, which this commit's retry expansion absorbs. Steady-state cycles still pending.
…-fixes' into kavink/post-2389-mx-v2-combined
…ry-extensions' into kavink/post-2389-mx-v2-combined
…k/post-2389-mx-v2-combined
These docs walked through the post-PR PrimeIntellect-ai#2389 design journey: the kernel-compile plan, the image-build plan, the source-baked image build notes, the status-and-plan doc, and the consolidating mx_v2 RFC. Useful as a process record while the design was in flux, but not part of what should land upstream alongside the ``weight_broadcast.type = "mx_v2"`` code change. The upstream-facing docs that explain how to use the ``mx_v2`` broadcast type will land in a separate docs-only PR once the code path is reviewed and merged. Removed: docs/proposals/build-notes-2026-05-28.md docs/proposals/image-build-mx-v2.md docs/proposals/post-pr2389-kernel-compile-plan.md docs/proposals/post-pr2389-mx-v2.md docs/proposals/post-pr2389-status-and-plan.md Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Stacks the post-#2389 work into one branch, targeting
nixl_mxas the base so it lands cleanly on top:transport/mx_rendezvous.py+ tests): the same three runtime patches the [Weight Transfer] NIXL + MX Integration #2389 deployment runs today, baked into the tree: heartbeat-based liveness + same-rank-only peer filter + freshest-per-rank dedup onadd_remote_agent. Closes the multi-NIC L3-subnet pairing failures + the stale-worker_rankNIXL_ERR_NOT_ALLOWEDrace.trainer/models/conversions/): addscutlass_fp8(e4m3 per-channel) alongside the existingbf16_cast+fp8_blockwise, withcompile_target+compile_metadatatagging plumbed throughTensorDescriptorV2. Gives receivers a Phase 3b-ready handle for refusing mismatched kernel layouts at discovery time.mx_v2pull-mode rail (weight_broadcast.type = \"mx_v2\"): a sibling of thenixl_mxpush-mode rail. Same MX server, same NIXL data plane, same per-cycle HTTP poke — inverted so the inference side pulls via vLLM's nativeWeightTransferEnginecontract. Coexists withnixl_mx; both selectable per-run via the existingweight_broadcast.typediscriminator.Full design: see
docs/proposals/post-pr2389-mx-v2.md. The companion external-audience write-up that walks through the same design with mermaid diagrams (architecture, push-vs-pull sequence, TT→HF flow, coexistence) is attemp/PrimeRL_mx_v2_Design.mdin my workspace — happy to inline it into the PR description if that helps review.Dependency
This PR depends on ai-dynamo/modelexpress#349 (currently DRAFT) — which ships the MX v2 client library (
MxV2TrainingPublisher,MxV2RefitReceiver,MxWeightTransferEngine) + Phase 4 multi-source slice planner + cluster bug fixes. The Dockerfile here (Dockerfile.cuda.mx-v2) pulls modelexpress from that branch (@kavink/post-2389-phase3-4). Once #349 lands in a tagged MX release, that pin becomes a regular version bump.Why pull-mode on top of push-mode
The push-mode design in #2389 is the right shape for the per-cycle critical path on a single homogeneous RL deployment. Four things motivated a parallel pull-mode rail:
WeightTransferEngineas the receiver-side contract every framework is converging on (vLLM, NemoRL/Dynamo, verl, prime-rl). A pull engine is one contract, four consumers.publish_self_as_replica). Trainer egress stays 1× regardless of receiver count.None of these require deleting the push-mode path. The right shape is two co-existing rails, selected per-run. Nothing in #2389 changes;
mx_v2is additive.Coexistence story
Each layer (config schema, trainer selector, inference worker registry, orchestrator scheduler) is a discriminated union keyed off
config.type:nixl_mx(#2389)mx_v2(this PR)NIXLMxWeightBroadcast(ShardedSlot +transport_plan.prepare_writes)NIXLMxV2WeightBroadcast(GatheredSlot +publisher.add_tensor+publish())NIXLMxWeightUpdateWorker(pre-registered param buffers)NIXLMxV2WeightUpdateWorker(MxWeightTransferEngine.receive_weights)update_weights+ MxRendezvous status flipsupdate_weights_v2+ engine-driven discovery + pullA/B comparison is a one-line
type = ...flip.What's in this PR (in commits)
feat(transport/mx): Phase-2 — heartbeat + freshest-per-rank dedup + same-rank filterfix(transport/mx_rendezvous): tolerate both modelexpress.heartbeat module pathsfeat(conversions): cutlass FP8 e4m3 per-channel + compile_target/metadata taggingmx_v2RFC:RFC: weight_broadcast.type=\"mx_v2\" — the complete prime-rl × ModelExpress designfeat(orchestrator): wire mx_v2 into the per-cycle refit pathbuild(mx_v2): fix Dockerfile uv path + smoke tests,bake flash-attn ARM64 stub + complete source overlaybuild/configs(mx_v2): full image overlay + orchestrator/inference config schemasfix(mx_v2): worker retry loop + trainer slot API + conversion-as-str fallbackfeat(mx_v2): receiver-side TT→HF translator for Qwen3-MoE pull-mode refitfeat(mx_v2): trainer GatheredSlot escalation + receiver NIXL transient retryTests
25 new unit tests, 5 CI-gated:
tests/unit/transport/test_mx_rendezvous_phase2.pytests/unit/train/models/conversions/test_cutlass_fp8.pytests/unit/train/rl/test_nixl_mx_v2.py(10 tests)tests/unit/inference/vllm/worker/test_nixl_mx_v2_worker.py(11 tests)tests/unit/inference/vllm/test_mx_v2_server_endpoints.py(3 + 5 gated)Local result: 25 passed, 5 skipped, no failures. The 5 skipped tests require the full prime-rl install for the FastAPI/httpx endpoint surface and run in CI.
Validation status
NIXLMxV2WeightUpdateWorker)Injected ... for extended collective_rpc calls ['_load_weights_batch', '_translate_tt_to_hf', 'init_nixl_mx_v2', 'update_weights_via_mx_v2']Reward: 1.0000/0.5000/0.6250on gsm8k/update_weights_v2round-tripcollective_rpc(\"update_weights_via_mx_v2\", ...)model_name + worker_rank + filtersdocs/proposals/post-pr2389-mx-v2.md); 20-min sync with #2389's owner on vLLM QKVParallelLinear narrow semantics will close itThe synthetic NIXL benchmarks (~30-50 GB/s on NVL, ~10ms catalog hit, retry backoff 0.5s→8.0s capped at
timeout_seconds) are in modelexpress#349.Open coordination questions
These are decisions where input from the #2389 owner would help before this lands ready-for-review (full list in §10 of
docs/proposals/post-pr2389-mx-v2.md):\"mx_v2\"the right tag, or do we want\"nixl_mx_pull\"to make the relationship to\"nixl_mx\"explicit?MxWeightTransferEnginelives inai-dynamo/modelexpress. Long-term, vLLM'sWeightTransferEngineABC is the right home, and MX-specificinit_info_cls/update_info_clsare the natural integration point the blog calls out. Confirm we're aligned with this trajectory.mx_v2-specific (flash-attn ARM64 stub, vLLM MoE oracle skipping FlashInfer TRTLLM/CUTLASS that needsnvcc,VLLM_USE_DEEP_GEMM=0, Triton MoE +TRITON_ATTN,classic_cuda_allocgraceful no-op when nvcc is missing). All pure additions to the base image. Happy to promote them into the base image build.mx_v2patchesSMALL_NON_EXPERT_BYTES = 1 << 60for the duration ofmodel.build_slots(...). Alternative is a per-broadcastnon_expert_layout = \"gathered\" | \"sharded\"flag onmodel.build_slots. Which interface do you prefer?mx_v2.1or replaces this GatheredSlot escalation.importlib.util.spec_from_file_location+sys.modules-stub pattern lets these tests run without a full prime-rl install. Same shape as the engine adapter tests on modelexpress#349. Happy to switch if you prefer a different convention.Test plan
ai-dynamo/modelexpress#349first so the modelexpress pin inDockerfile.cuda.mx-v2can move from@kavink/post-2389-phase3-4to a tagged releasenixl_mxbaselineRelated